library(here)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))

Introduction

  • Goal is to model diagnoses from each criteria based on how frequently they co-occur together within a single 10-point differential diagnosis iteration.
  • The rationale is that an ideal set of criteria will tend to result in a relatively similar set of diagnoses in each iteration and thus a small set of highly co-occurring diagnoses.
  • Conversely, a poor set of criteria will generate highly variable diagnoses in each iteration and rates of co-occurrence will be comparatively lower.

Import and process data

# Load one data frame per model run with icd = FALSE
# (the "original responses" used in the sections below).
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
# Load the same six model runs again with icd = TRUE
# (the "ICD converted responses" used in the sections below).
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
# Tally each pair of co-occurring diagnoses within each criteria.
# `TRUE` is spelled out: `T` is an ordinary variable that can be reassigned.
# Original responses
df_gpt3.5_codiag <- create_codiagnosis_df(df_gpt3.5, remove_singletons = TRUE)
df_gpt4.0_codiag <- create_codiagnosis_df(df_gpt4.0, remove_singletons = TRUE)
df_claude3_haiku_t1.0_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0, remove_singletons = TRUE)
df_claude3_opus_t1.0_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0, remove_singletons = TRUE)
# ICD converted responses
df_gpt3.5_icd_codiag <- create_codiagnosis_df(df_gpt3.5_icd, remove_singletons = TRUE)
df_gpt4.0_icd_codiag <- create_codiagnosis_df(df_gpt4.0_icd, remove_singletons = TRUE)
df_claude3_haiku_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0_icd, remove_singletons = TRUE)
df_claude3_opus_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0_icd, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0_icd, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0_icd, remove_singletons = TRUE)
# Preview one of the tallied tables
df_gpt4.0_codiag

Graph visualization

Exploring layouts

  • To determine clearest visualization of nodes and edges
# Draw the same GPT-4 graph once per candidate layout algorithm
seed <- 1234
top_n <- 200
layouts <- c("fr", "dh", "kk", "stress", "graphopt")

graph_top_gpt4 <- make_codiagnosis_graph(
  df_gpt4.0_codiag,
  n_diagnoses = top_n
)

for (layout_name in layouts) {
  # Re-seed before every draw so each layout starts from the same RNG state
  set.seed(seed)
  plt <- centrality_graph(graph_top_gpt4, layout = layout_name) +
    ggtitle("GPT4", subtitle = sprintf("Layout %s", layout_name))
  print(plt)
}

Individual model data

Original responses

top_n <- 100
seed <- 321
graph_layout <- "stress"

# Build and draw the co-diagnosis graph for a single model.
#
# data:        co-diagnosis table (the callers below pass the output of
#              create_codiagnosis_df())
# n_diagnoses: forwarded to make_codiagnosis_graph(); defaults to `top_n`
# layout:      forwarded to centrality_graph(); defaults to `graph_layout`
# rng_seed:    seed set before the graph is built; defaults to `seed`
#
# Returns whatever centrality_graph() returns (callers add ggtitle() to it).
# The defaults are evaluated lazily at call time, so one-argument calls still
# pick up the current values of the globals exactly as before.
codiag_graph_wrapper <- function(data,
                                 n_diagnoses = top_n,
                                 layout = graph_layout,
                                 rng_seed = seed) {
  set.seed(rng_seed)
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diagnoses)
  centrality_graph(graph, layout = layout)
}

# One graph per model (original responses); the subtitle is shared
top_n_subtitle <- sprintf("Top %s", top_n)

codiag_graph_wrapper(df_gpt3.5_codiag) +
  ggtitle("ChatGPT 3.5", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_gpt4.0_codiag) +
  ggtitle("ChatGPT 4.0", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_claude3_haiku_t1.0_codiag) +
  ggtitle("Claude 3 Haiku", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_claude3_opus_t1.0_codiag) +
  ggtitle("Claude 3 Opus", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_gemini1.0_pro_t1.0_codiag) +
  ggtitle("Gemini 1.0 Pro", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_gemini1.5_pro_t1.0_codiag) +
  ggtitle("Gemini 1.5 Pro", subtitle = top_n_subtitle)

ICD converted responses

# One graph per model (ICD converted responses); the subtitle is shared
top_n_subtitle <- sprintf("Top %s", top_n)

codiag_graph_wrapper(df_gpt3.5_icd_codiag) +
  ggtitle("ChatGPT 3.5 ICD", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_gpt4.0_icd_codiag) +
  ggtitle("ChatGPT 4.0 ICD", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_claude3_haiku_t1.0_icd_codiag) +
  ggtitle("Claude 3 Haiku ICD", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_claude3_opus_t1.0_icd_codiag) +
  ggtitle("Claude 3 Opus ICD", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_gemini1.0_pro_t1.0_icd_codiag) +
  ggtitle("Gemini 1.0 Pro ICD", subtitle = top_n_subtitle)

codiag_graph_wrapper(df_gemini1.5_pro_t1.0_icd_codiag) +
  ggtitle("Gemini 1.5 Pro ICD", subtitle = top_n_subtitle)

Combined model data

Original responses

# Combined graph of all six models (original responses), drawn with two
# threshold methods. The seed is set before EACH call so that either figure is
# reproducible on its own, regardless of which one is (re)run first.
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "individual", top_n = 100, layout = "stress",
                             df_gpt3.5_codiag, df_gpt4.0_codiag,
                             df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
                             df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag)

set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "average", top_n = 100, layout = "stress",
                             df_gpt3.5_codiag, df_gpt4.0_codiag,
                             df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
                             df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag)

ICD converted responses

# Combined graph of all six models (ICD converted responses), drawn with two
# threshold methods. The seed is set before EACH call so that either figure is
# reproducible on its own, regardless of which one is (re)run first.
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "individual", top_n = 100, layout = "stress",
                             df_gpt3.5_icd_codiag, df_gpt4.0_icd_codiag,
                             df_claude3_haiku_t1.0_icd_codiag, df_claude3_opus_t1.0_icd_codiag,
                             df_gemini1.0_pro_t1.0_icd_codiag, df_gemini1.5_pro_t1.0_icd_codiag)

set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "average", top_n = 100, layout = "stress",
                             df_gpt3.5_icd_codiag, df_gpt4.0_icd_codiag,
                             df_claude3_haiku_t1.0_icd_codiag, df_claude3_opus_t1.0_icd_codiag,
                             df_gemini1.0_pro_t1.0_icd_codiag, df_gemini1.5_pro_t1.0_icd_codiag)

Evaluate edge density

  • Edge density represents the total number of edges in a graph relative to the total number of possible edges in the graph
    • When all possible edges are present, edge density = 1
    • When no edges are present, edge density = 0
  • Dense co-occurrence networks represent criteria that generate highly reproducible diagnoses. A sparse co-occurrence network represents a high degree of variability in the diagnoses generated.

Combined model data

Original responses

# Edge-density plot across all six models (original responses).
# The result is not assigned, so it is printed directly.
multi_edge_density_plot(
  df_gpt3.5_codiag,
  df_gpt4.0_codiag,
  df_claude3_haiku_t1.0_codiag,
  df_claude3_opus_t1.0_codiag,
  df_gemini1.0_pro_t1.0_codiag,
  df_gemini1.5_pro_t1.0_codiag
)  
Warning: The `fun.y` argument of `stat_summary()` is deprecated as of ggplot2 3.3.0.
Please use the `fun` argument instead.Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
Please use `linewidth` instead.Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.

ICD converted responses

# Edge-density plot across all six models (ICD converted responses).
# Kept in a variable because its `$data` and p-values are inspected afterwards.
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)  
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.
# Show the ICD edge-density figure
plt_edge_icd

# Mean edge density per criteria, computed from the data stored on the plot
summarise(plt_edge_icd$data, mean(edge_density), .by = "criteria")

# Project helper; presumably returns the p-values annotated on the plot
extract_ggpubr_pvalues(plt_edge_icd)

Plot centrality

# Compute node centrality separately inside each criteria's subgraph.
#
# g:              graph whose edges carry a `criteria` attribute
# centrality_fun: name of a zero-argument centrality function that works inside
#                 mutate() on a tbl_graph (e.g. "centrality_eigen"), or the
#                 function object itself
#
# Returns a data frame with criteria as row names and one column per node
# name; a node missing from a criteria's subgraph gets centrality 0.
calculate_subgraph_centrality <- function(g, centrality_fun = "centrality_eigen") {
  # Resolve the function once, outside the data mask, and only accept functions
  centrality_fn <- if (is.function(centrality_fun)) {
    centrality_fun
  } else {
    get(centrality_fun, mode = "function")
  }

  criteria_levels <- g %>%
    activate(edges) %>%
    pull(criteria) %>%
    unique()

  data.frame(criteria = criteria_levels) %>%
    mutate(sub_graphs = map(criteria, function(crit) {
      # `crit` rather than `c`, so base::c is not shadowed inside the pipeline
      g %>%
        as_data_frame() %>%
        filter(criteria == crit) %>%
        as_tbl_graph(directed = FALSE) %>%
        activate(nodes) %>%
        mutate(centrality = centrality_fn()) %>%
        data.frame()
    })) %>%
    unnest(sub_graphs) %>%
    pivot_wider(names_from = "name", values_from = "centrality", values_fill = 0) %>%
    column_to_rownames("criteria")
}

# NOTE(review): `top_n` is whatever was last assigned (100 in the
# individual-model section), although the variable is named "all" - confirm.
graph_all_gpt4 <- make_codiagnosis_graph(df_gpt4.0_codiag, n_diagnoses = top_n)
df_centr_gpt4 <- calculate_subgraph_centrality(graph_all_gpt4)
# Pairwise cosine similarity between criteria, from a criteria-by-node
# centrality table (criteria in the row names).
# Returns a long table with columns V1, V2 and cosine.
centrality_similarity <- function(data) {
  # Pass the criteria labels through format_criteria() (project helper)
  labelled <- data %>%
    rownames_to_column("criteria") %>%
    format_criteria() %>%
    column_to_rownames("criteria")

  # lsa::cosine() compares columns, so transpose to get one criteria per column
  cosine_matrix <- lsa::cosine(t(as.matrix(labelled)))

  cosine_matrix %>%
    as.data.frame() %>%
    rownames_to_column("V1") %>%
    pivot_longer(-V1, names_to = "V2", values_to = "cosine")
}

centrality_similarity(df_centr_gpt4)
# Co-diagnosis table -> graph -> per-criteria centrality -> cosine similarity.
# n_diag is forwarded to make_codiagnosis_graph() as n_diagnoses.
centrality_wrapper <- function(data, n_diag = NULL) {
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diag)
  centrality_table <- calculate_subgraph_centrality(graph)
  centrality_similarity(centrality_table)
}

# Step 1: cosine similarities between criteria, computed per model
cosine_by_model <- listN(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) %>%
  enframe(name = "model", value = "data") %>%
  mutate(data = map(data, centrality_wrapper)) %>%
  unnest(data)

# Step 2: average each criteria pair over the models and reshape to a matrix
average_cosine_matrix <- cosine_by_model %>%
  dplyr::summarise(cosine = mean(cosine), .by = c("V1", "V2")) %>%
  pivot_wider(names_from = "V2", values_from = "cosine") %>%
  column_to_rownames("V1") %>%
  as.matrix()

average_cosine_matrix
                   SLE - SLICC SLE - EULAR-ACR MCAS - Consortium MCAS - Alternative
SLE - SLICC          1.0000000       0.8540815         0.2276693          0.5103618
SLE - EULAR-ACR      0.8540815       1.0000000         0.2682776          0.4819197
MCAS - Consortium    0.2276693       0.2682776         1.0000000          0.5459547
MCAS - Alternative   0.5103618       0.4819197         0.5459547          1.0000000
# Draw a ComplexHeatmap of a numeric matrix using a three-point colour ramp.
#
# data:               numeric matrix to plot
# plot_title:         optional column title drawn above the heatmap (NULL = none)
# legend_title:       optional; passed as the heatmap `name`
# color_scale:        three colours for (min, midpoint, max); NULL selects
#                     hcl.colors(3, "Earth") when symmetric, viridis otherwise
# midpoint:           value mapped to the middle colour; NULL = centre of scale
# symmetric:          TRUE  -> scale spans -max(|data|) .. max(|data|)
#                     FALSE -> scale spans min(data) .. max(data)
# matrix_title_size:  font size of plot_title
# matrix_names_size:  font size of row and column names
# legend_title_size, legend_label_size: legend font sizes
# dendrograms:        show row and column dendrograms
# dendrogram_weight:  dendrogram height/width (only used when dendrograms = TRUE)
# legend_orientation: "vertical" (legend on the right) or "horizontal" (bottom)
# legend_length:      optional legend height (vertical) or width (horizontal)
# grid_lines:         draw a black border around every cell
#
# Draws the heatmap and returns the value of draw().
custom_heatmap <- function(data,
                           plot_title = NULL,
                           legend_title = NULL,
                           color_scale = NULL,
                           midpoint = NULL,
                           symmetric = TRUE,
                           matrix_title_size = 10,
                           matrix_names_size = 8,
                           legend_title_size = 10,
                           legend_label_size = 8,
                           dendrograms = TRUE,
                           dendrogram_weight = unit(10, "mm"),
                           legend_orientation = "vertical",
                           legend_length = NULL,
                           grid_lines = FALSE) {
  # Limits and centre of the colour scale (missing cells are ignored)
  if (symmetric) {
    scale_max <- max(abs(data), na.rm = TRUE)
    scale_min <- -scale_max
  } else {
    scale_max <- max(data, na.rm = TRUE)
    scale_min <- min(data, na.rm = TRUE)
  }
  if (is.null(midpoint)) {
    midpoint <- scale_min + (scale_max - scale_min) / 2
  }

  # Default colour scale depends on whether the data are symmetric
  if (is.null(color_scale)) {
    color_scale <- if (symmetric) hcl.colors(3, "Earth") else viridis::viridis(3)
  }
  color_function <- circlize::colorRamp2(
    c(scale_min, midpoint, scale_max),
    color_scale
  )

  # Legend parameters
  legend_params <- list(
    "title_gp" = grid::gpar(fontsize = legend_title_size, fontface = "bold"),
    "labels_gp" = grid::gpar(fontsize = legend_label_size),
    "direction" = legend_orientation
  )
  if (!is.null(legend_length)) {
    if (legend_orientation == "vertical") {
      legend_params[["legend_height"]] <- legend_length
    } else if (legend_orientation == "horizontal") {
      legend_params[["legend_width"]] <- legend_length
    }
  }
  legend_side <- if (legend_orientation == "vertical") "right" else "bottom"

  # Heatmap parameters; optional ones are only added when requested
  heatmap_arguments <- list(
    "matrix" = data,
    "col" = color_function,
    "row_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "column_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "show_column_dend" = dendrograms,
    "show_row_dend" = dendrograms,
    "heatmap_legend_param" = legend_params
  )
  if (!is.null(plot_title)) {
    # Previously plot_title / matrix_title_size were accepted but never used
    heatmap_arguments[["column_title"]] <- plot_title
    heatmap_arguments[["column_title_gp"]] <- grid::gpar(fontsize = matrix_title_size)
  }
  if (!is.null(legend_title)) {
    heatmap_arguments[["name"]] <- legend_title
  }
  if (grid_lines) {
    heatmap_arguments[["rect_gp"]] <- grid::gpar(col = "black", lwd = 1)
  }
  if (dendrograms) {
    heatmap_arguments[["column_dend_height"]] <- dendrogram_weight
    heatmap_arguments[["row_dend_width"]] <- dendrogram_weight
  }

  ht <- do.call(ComplexHeatmap::Heatmap, heatmap_arguments)
  ComplexHeatmap::draw(
    ht,
    heatmap_legend_side = legend_side,
    align_heatmap_legend = "global_center"
  )
}

custom_heatmap(
  average_cosine_matrix,
  symmetric = FALSE,
  legend_title = "Cosine similarity",
  grid_lines = TRUE,
  dendrograms = FALSE,
  legend_orientation = "horizontal",
  legend_length = unit(10, "cm")
)

# Per-model cosine similarities between criteria, in long form
# (one row per model and criteria pair)
df_comp <- enframe(
  listN(
    df_gpt3.5_icd_codiag,
    df_gpt4.0_icd_codiag,
    df_claude3_haiku_t1.0_icd_codiag,
    df_claude3_opus_t1.0_icd_codiag,
    df_gemini1.0_pro_t1.0_icd_codiag,
    df_gemini1.5_pro_t1.0_icd_codiag
  ),
  name = "model",
  value = "data"
)
df_comp$data <- map(df_comp$data, centrality_wrapper)
df_comp <- unnest(df_comp, data)
# Plot per-model cosine similarity for the two within-disease criteria pairs
# and annotate the comparison between them with a Wilcoxon test p-value.
#
# df:      long table with columns model, V1, V2 and cosine
# pt_size: point size
# p_size:  text size of the p-value label
cosine_similarity_compare <- function(df, pt_size = 1, p_size = 3) {
  # Only these two "V1_V2" combinations are kept
  kept_pairs <- c(
    "MCAS - Consortium_MCAS - Alternative",
    "SLE - EULAR-ACR_SLE - SLICC"
  )

  plot_data <- df %>%
    unite(comp, V1, V2) %>%
    filter(comp %in% kept_pairs) %>%
    mutate(comp = gsub("_", "\nvs. ", comp)) %>%
    format_models()

  ggplot(plot_data, aes(x = comp, y = cosine, color = model)) +
    geom_point(size = pt_size, position = position_dodge(width = 0.75)) +
    theme_bw() +
    ggpubr::stat_compare_means(
      aes(group = comp),
      method = "wilcox.test",
      label = "p",
      vjust = 0.75,
      show.legend = FALSE,
      size = p_size
    ) +
    theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
    scale_color_brewer(palette = "Dark2") +
    labs(x = NULL, y = "Cosine similarity", color = "")
}
cosine_similarity_compare(df_comp)

# Cosine similarity between every (model, criteria) combination.
# Each model's centrality table is reshaped to long form, the tables are
# combined, and the result is widened into a combination-by-diagnosis matrix.
individual_cosine_matrix <- combine_data_frames(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag,
  additional_function = function(x) {
    x %>%
      make_codiagnosis_graph() %>%
      calculate_subgraph_centrality() %>%
      rownames_to_column("criteria") %>%
      pivot_longer(-criteria, names_to = "diagnosis", values_to = "centrality")
  }
) %>%
  unite(criteria, original_df, criteria, sep = "-") %>%
  pivot_wider(names_from = "diagnosis", values_from = "centrality", values_fill = 0) %>%
  column_to_rownames("criteria") %>%
  as.matrix() %>%
  # lsa::cosine() compares columns, so put one combination per column
  t() %>%
  # The result is already a named square matrix; the former
  # pivot_longer() -> pivot_wider() -> as.matrix() round trip was a no-op.
  lsa::cosine()

individual_cosine_matrix[1:5, 1:5]
                                      df_gpt3.5_icd_codiag-slicc_sle df_gpt3.5_icd_codiag-eular_acr_sle df_gpt3.5_icd_codiag-mcas_consortium df_gpt3.5_icd_codiag-mcas_alternative df_gpt4.0_icd_codiag-slicc_sle
df_gpt3.5_icd_codiag-slicc_sle                             1.0000000                          0.8647922                            0.2136086                             0.4720033                      0.8522184
df_gpt3.5_icd_codiag-eular_acr_sle                         0.8647922                          1.0000000                            0.2643814                             0.4613926                      0.7324445
df_gpt3.5_icd_codiag-mcas_consortium                       0.2136086                          0.2643814                            1.0000000                             0.5519625                      0.1663721
df_gpt3.5_icd_codiag-mcas_alternative                      0.4720033                          0.4613926                            0.5519625                             1.0000000                      0.3913603
df_gpt4.0_icd_codiag-slicc_sle                             0.8522184                          0.7324445                            0.1663721                             0.3913603                      1.0000000
# Heatmap of the per-model, per-criteria cosine matrix.
# `FALSE` is spelled out: `F` is an ordinary variable that can be reassigned.
alt_heatmap <- model_criteria_heatmap(
  individual_cosine_matrix,
  color_scale = viridis::viridis(3),
  title = " ",
  metric = "Cosine\nsimilarity",
  symmetric = FALSE,
  font_size = 8
)

alt_heatmap

# Capture the drawn heatmap as a grob so it can be saved with ggsave()
plt_alt <- cowplot::plot_grid(
  grid::grid.grabExpr(ComplexHeatmap::draw(alt_heatmap))
)
plt_alt

ggsave(here("figures/4_alt_heatmap.pdf"), plot = plt_alt, height = 4, width = 6)

Final plot

Version 1

# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50

# Shared font sizes and legend padding for the final figure
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0

# Theme additions applied to ggplot panels of the final figure.
# Stored as a list so it can be added to a plot with `+`.
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    # Spelled out: the former margin(0,0,0,r=1) relied on positional filling
    legend.text = element_text(
      size = label_size,
      margin = margin(t = 0, r = 1, b = 0, l = 0)
    ),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    # `linewidth` replaces `size`, which is deprecated for rect/line elements
    # since ggplot2 3.4.0 (this theme already needs >= 3.5.0 for
    # legend.key.spacing.y).
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
)


# Uniform margin around facet strip text
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)),
  strip.text.y = element_text(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Panel A of the final figure: combined co-diagnosis graph, "average" threshold.
# NOTE(review): despite the `_icd` suffix, this graph is built from the
# original-response tables (`*_codiag`), while the edge-density and heatmap
# panels use the ICD-converted tables (`*_icd_codiag`) - confirm which data
# set is intended here.
set.seed(1234)
plt_graph_icd <-
  multi_make_codiagnosis_graph(
    threshold_method = "average",
    top_n = 100,
    layout = "stress",
    df_gpt3.5_codiag,
    df_gpt4.0_codiag,
    df_claude3_haiku_t1.0_codiag,
    df_claude3_opus_t1.0_codiag,
    df_gemini1.0_pro_t1.0_codiag,
    df_gemini1.5_pro_t1.0_codiag,
    # Appearance settings sized for the small final figure
    point_size = 1.25,
    border_size = 0.25,
    edge_width = 0.5,
    edge_alpha = 0.5,
    label_text_size = 9,
    tick_text_size = 6,
    highlight_stroke_multiplier = 3,
    legend_height = unit(25, "pt"),
    legend_width = unit(10, "pt")
  ) 

# Panel B of the final figure: ICD edge-density plot ...
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)

# ... restyled for the small figure: shared text sizes, a two-column legend
# underneath, and 10% padding on both ends of the y axis
plt_edge_icd <- plt_edge_icd +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(override.aes = list(size = 1), ncol = 2)) +
  scale_y_continuous(expand = expansion(mult = c(0.1, 0.1)))
Scale for colour is already present.
Adding another scale for colour, which will replace the existing scale.
# Panel C of the final figure: averaged cosine-similarity heatmap, with fonts,
# legend and dendrograms shrunk for the small figure.
# `TRUE`/`FALSE` are spelled out: `T`/`F` are reassignable variables.
plt_heatmap_icd <-
  custom_heatmap(
    average_cosine_matrix,
    symmetric = FALSE,
    legend_title = "Cosine similarity",
    grid_lines = TRUE,
    dendrograms = TRUE,
    legend_orientation = "horizontal",
    legend_length = unit(2, "cm"),
    matrix_names_size = 6,
    legend_title_size = 7.5,
    legend_label_size = 6,
    dendrogram_weight = unit(2, "mm")
  )



# Assemble the final figure: panel A (graph) on top, panels B (edge density)
# and C (heatmap) side by side below. NULL entries are spacers.
plt_fig <- cowplot::plot_grid(
  NULL,
  cowplot::plot_grid(
    NULL, plt_graph_icd,
    nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  NULL,
  cowplot::plot_grid(
    cowplot::plot_grid(NULL, plt_edge_icd, rel_heights = c(0.1, 1), ncol = 1),
    NULL,
    cowplot::plot_grid(
      NULL,
      grid::grid.grabExpr(
        ComplexHeatmap::draw(plt_heatmap_icd, heatmap_legend_side = "bottom")
      ),
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.5, 0.1, 0.6),
    labels = c("B", "", "C")
    # The former `axis = 'h', align = 'bt'` had its two values swapped
    # (`align` takes "h"/"v"/"hv", `axis` takes "t"/"b"/"l"/"r" combinations)
    # and so requested no alignment. It is dropped to keep the spacer-tuned
    # layout unchanged; use `align = "h", axis = "bt"` if alignment is wanted.
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.01, 1),
  labels = "A"
)

plt_fig

Version 2

# Save the assembled multi-panel figure (`plt_fig`) as a 3.5 x 5.5 PDF
ggsave(here("figures/4_Network_analysis.pdf"), plot=plt_fig, height = 5.5, width = 3.5)
---
title: "Co-diagnosis network analysis"
output: 
  html_notebook:
    toc: true
    toc_float: true
---

```{r, message = F}
library(here)
source(here("utils/data_processing.R"))
source(here("utils/figures.R"))
```

# Introduction

- Goal is to model diagnoses from each criteria based on how frequently they co-occur together within a single 10-point differential diagnosis iteration.
- The rationale is that an ideal set of criteria will tend to result in a relatively similar set of diagnoses in each iteration and thus a small set of highly co-occurring diagnoses.
- Conversely, a poor set of criteria will generate highly variable diagnoses in each iteration and rates of co-occurrence will be comparatively lower.

# Import and process data

```{r, message = F}
# Load one data frame per model run with icd = FALSE
# (the "original responses" used in the sections below).
df_gpt3.5 <- read_model("gpt-3.5-turbo-1106", icd = FALSE)
df_gpt4.0 <- read_model("gpt-4-turbo-preview", icd = FALSE)
df_claude3_haiku_t1.0 <- read_model("claude-3-haiku-20240307_t1-0", icd = FALSE)
df_claude3_opus_t1.0 <- read_model("claude-3-opus-20240229_t1-0", icd = FALSE)
df_gemini1.0_pro_t1.0 <- read_model("gemini-1.0-pro-002_t1-0", icd = FALSE)
df_gemini1.5_pro_t1.0 <- read_model("gemini-1.5-pro-001_t1-0", icd = FALSE)
```

```{r, message = F}
# Load the same six model runs with icd = TRUE
# (the "ICD converted responses" used in the sections below).
df_gpt3.5_icd <- read_model("gpt-3.5-turbo-1106", icd = TRUE)
df_gpt4.0_icd <- read_model("gpt-4-turbo-preview", icd = TRUE)
df_claude3_haiku_t1.0_icd <- read_model("claude-3-haiku-20240307_t1-0", icd = TRUE)
df_claude3_opus_t1.0_icd <- read_model("claude-3-opus-20240229_t1-0", icd = TRUE)
df_gemini1.0_pro_t1.0_icd <- read_model("gemini-1.0-pro-002_t1-0", icd = TRUE)
df_gemini1.5_pro_t1.0_icd <- read_model("gemini-1.5-pro-001_t1-0", icd = TRUE)
```

```{r}
# Tally each pair of co-occurring diagnoses within each criteria.
# `TRUE` is spelled out: `T` is an ordinary variable that can be reassigned.
# Original responses
df_gpt3.5_codiag <- create_codiagnosis_df(df_gpt3.5, remove_singletons = TRUE)
df_gpt4.0_codiag <- create_codiagnosis_df(df_gpt4.0, remove_singletons = TRUE)
df_claude3_haiku_t1.0_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0, remove_singletons = TRUE)
df_claude3_opus_t1.0_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0, remove_singletons = TRUE)
```


```{r}
# ICD converted responses: tally co-occurring diagnosis pairs per criteria.
# `TRUE` is spelled out: `T` is an ordinary variable that can be reassigned.
df_gpt3.5_icd_codiag <- create_codiagnosis_df(df_gpt3.5_icd, remove_singletons = TRUE)
df_gpt4.0_icd_codiag <- create_codiagnosis_df(df_gpt4.0_icd, remove_singletons = TRUE)
df_claude3_haiku_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_haiku_t1.0_icd, remove_singletons = TRUE)
df_claude3_opus_t1.0_icd_codiag <- create_codiagnosis_df(df_claude3_opus_t1.0_icd, remove_singletons = TRUE)
df_gemini1.0_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.0_pro_t1.0_icd, remove_singletons = TRUE)
df_gemini1.5_pro_t1.0_icd_codiag <- create_codiagnosis_df(df_gemini1.5_pro_t1.0_icd, remove_singletons = TRUE)
```


```{r}
# Preview one of the tallied co-diagnosis tables
df_gpt4.0_codiag
```

# Graph visualization

## Exploring layouts 

- To determine clearest visualization of nodes and edges

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Draw the same GPT-4 graph once per candidate layout algorithm
seed <- 1234
top_n <- 200
layouts <- c("fr", "dh", "kk", "stress", "graphopt")

graph_top_gpt4 <- make_codiagnosis_graph(
  df_gpt4.0_codiag,
  n_diagnoses = top_n
)

for (layout_name in layouts) {
  # Re-seed before every draw so each layout starts from the same RNG state
  set.seed(seed)
  plt <- centrality_graph(graph_top_gpt4, layout = layout_name) +
    ggtitle("GPT4", subtitle = sprintf("Layout %s", layout_name))
  print(plt)
}

```
## Individual model data

### Original responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
top_n <- 100
seed <- 321
graph_layout <- "stress"

# Build and draw the co-diagnosis graph for a single model.
#
# data:        co-diagnosis table (the callers below pass the output of
#              create_codiagnosis_df())
# n_diagnoses: forwarded to make_codiagnosis_graph(); defaults to `top_n`
# layout:      forwarded to centrality_graph(); defaults to `graph_layout`
# rng_seed:    seed set before the graph is built; defaults to `seed`
#
# Returns whatever centrality_graph() returns (callers add ggtitle() to it).
# The defaults are evaluated lazily at call time, so one-argument calls still
# pick up the current values of the globals exactly as before.
codiag_graph_wrapper <- function(data,
                                 n_diagnoses = top_n,
                                 layout = graph_layout,
                                 rng_seed = seed) {
  set.seed(rng_seed)
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diagnoses)
  centrality_graph(graph, layout = layout)
}

# One graph per model (original responses); the subtitle is shared
top_n_subtitle <- sprintf("Top %s", top_n)

codiag_graph_wrapper(df_gpt3.5_codiag) +
  ggtitle("ChatGPT 3.5", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_gpt4.0_codiag) +
  ggtitle("ChatGPT 4.0", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_claude3_haiku_t1.0_codiag) +
  ggtitle("Claude 3 Haiku", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_claude3_opus_t1.0_codiag) +
  ggtitle("Claude 3 Opus", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_gemini1.0_pro_t1.0_codiag) +
  ggtitle("Gemini 1.0 Pro", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_gemini1.5_pro_t1.0_codiag) +
  ggtitle("Gemini 1.5 Pro", subtitle = top_n_subtitle)
```
### ICD converted responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# One graph per model (ICD converted responses); the subtitle is shared
top_n_subtitle <- sprintf("Top %s", top_n)

codiag_graph_wrapper(df_gpt3.5_icd_codiag) +
  ggtitle("ChatGPT 3.5 ICD", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_gpt4.0_icd_codiag) +
  ggtitle("ChatGPT 4.0 ICD", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_claude3_haiku_t1.0_icd_codiag) +
  ggtitle("Claude 3 Haiku ICD", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_claude3_opus_t1.0_icd_codiag) +
  ggtitle("Claude 3 Opus ICD", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_gemini1.0_pro_t1.0_icd_codiag) +
  ggtitle("Gemini 1.0 Pro ICD", subtitle = top_n_subtitle)
codiag_graph_wrapper(df_gemini1.5_pro_t1.0_icd_codiag) +
  ggtitle("Gemini 1.5 Pro ICD", subtitle = top_n_subtitle)
```

## Combined model data

### Original responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined graph of all six models (original responses),
# "individual" threshold method
set.seed(seed)
multi_make_codiagnosis_graph(
  threshold_method = "individual",
  top_n = 100,
  layout = "stress",
  df_gpt3.5_codiag,
  df_gpt4.0_codiag,
  df_claude3_haiku_t1.0_codiag,
  df_claude3_opus_t1.0_codiag,
  df_gemini1.0_pro_t1.0_codiag,
  df_gemini1.5_pro_t1.0_codiag
)
```

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined graph of all six models (original responses),
# "average" threshold method.
# Seeded like the "individual" chunk so this chunk is reproducible when it is
# run on its own.
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "average", top_n = 100, layout = "stress",
                             df_gpt3.5_codiag, df_gpt4.0_codiag,
                             df_claude3_haiku_t1.0_codiag, df_claude3_opus_t1.0_codiag,
                             df_gemini1.0_pro_t1.0_codiag, df_gemini1.5_pro_t1.0_codiag)
```
### ICD converted responses

```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined graph of all six models (ICD converted responses),
# "individual" threshold method
set.seed(seed)
multi_make_codiagnosis_graph(
  threshold_method = "individual",
  top_n = 100,
  layout = "stress",
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)
```


```{r, fig.width=8, fig.height=6, warning=F, message=F}
# Combined graph of all six models (ICD converted responses),
# "average" threshold method.
# Seeded like the "individual" chunk so this chunk is reproducible when it is
# run on its own.
set.seed(seed)
multi_make_codiagnosis_graph(threshold_method = "average", top_n = 100, layout = "stress",
                             df_gpt3.5_icd_codiag, df_gpt4.0_icd_codiag,
                             df_claude3_haiku_t1.0_icd_codiag, df_claude3_opus_t1.0_icd_codiag,
                             df_gemini1.0_pro_t1.0_icd_codiag, df_gemini1.5_pro_t1.0_icd_codiag)
```

# Evaluate edge density

- Edge density represents the total number of edges in a graph relative to the total number of **possible** edges in the graph
  - When all possible edges are present, edge density = 1
  - When no edges are present, edge density = 0
- Dense co-occurrence networks represent criteria that generate highly reproducible diagnoses. A sparse co-occurrence network represents a high degree of variability in the diagnoses generated.

## Combined model data

### Original responses

```{r, fig.width=4, fig.height=3.5}
# Edge-density plot across all six models (original responses).
# The result is not assigned, so it is printed directly.
multi_edge_density_plot(
  df_gpt3.5_codiag,
  df_gpt4.0_codiag,
  df_claude3_haiku_t1.0_codiag,
  df_claude3_opus_t1.0_codiag,
  df_gemini1.0_pro_t1.0_codiag,
  df_gemini1.5_pro_t1.0_codiag
)  

```

### ICD converted responses

```{r, fig.width=4, fig.height=3.5}
# Edge-density plot across all six models (ICD converted responses)
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
)
plt_edge_icd

# Mean edge density per criteria, computed from the data stored on the plot
summarise(plt_edge_icd$data, mean(edge_density), .by = "criteria")

# Project helper; presumably returns the p-values annotated on the plot
extract_ggpubr_pvalues(plt_edge_icd)
```



# Plot centrality
```{r}
# Compute node centrality separately inside each criteria's subgraph.
#
# g:              graph whose edges carry a `criteria` attribute
# centrality_fun: name of a zero-argument centrality function that works inside
#                 mutate() on a tbl_graph (e.g. "centrality_eigen"), or the
#                 function object itself
#
# Returns a data frame with criteria as row names and one column per node
# name; a node missing from a criteria's subgraph gets centrality 0.
calculate_subgraph_centrality <- function(g, centrality_fun = "centrality_eigen") {
  # Resolve the function once, outside the data mask, and only accept functions
  centrality_fn <- if (is.function(centrality_fun)) {
    centrality_fun
  } else {
    get(centrality_fun, mode = "function")
  }

  criteria_levels <- g %>%
    activate(edges) %>%
    pull(criteria) %>%
    unique()

  data.frame(criteria = criteria_levels) %>%
    mutate(sub_graphs = map(criteria, function(crit) {
      # `crit` rather than `c`, so base::c is not shadowed inside the pipeline
      g %>%
        as_data_frame() %>%
        filter(criteria == crit) %>%
        as_tbl_graph(directed = FALSE) %>%
        activate(nodes) %>%
        mutate(centrality = centrality_fn()) %>%
        data.frame()
    })) %>%
    unnest(sub_graphs) %>%
    pivot_wider(names_from = "name", values_from = "centrality", values_fill = 0) %>%
    column_to_rownames("criteria")
}

# NOTE(review): `top_n` is whatever was last assigned (100 in the
# individual-model chunk), although the variable is named "all" - confirm.
graph_all_gpt4 <- make_codiagnosis_graph(df_gpt4.0_codiag, n_diagnoses = top_n)
df_centr_gpt4 <- calculate_subgraph_centrality(graph_all_gpt4)
```

```{r}
# Pairwise cosine similarity between criteria, from a criteria-by-node
# centrality table (criteria in the row names).
# Returns a long table with columns V1, V2 and cosine.
centrality_similarity <- function(data) {
  # Pass the criteria labels through format_criteria() (project helper)
  labelled <- data %>%
    rownames_to_column("criteria") %>%
    format_criteria() %>%
    column_to_rownames("criteria")

  # lsa::cosine() compares columns, so transpose to get one criteria per column
  cosine_matrix <- lsa::cosine(t(as.matrix(labelled)))

  cosine_matrix %>%
    as.data.frame() %>%
    rownames_to_column("V1") %>%
    pivot_longer(-V1, names_to = "V2", values_to = "cosine")
}

centrality_similarity(df_centr_gpt4)
```
```{r}
# Co-diagnosis table -> graph -> per-criteria centrality -> cosine similarity.
# n_diag is forwarded to make_codiagnosis_graph() as n_diagnoses.
centrality_wrapper <- function(data, n_diag = NULL) {
  graph <- make_codiagnosis_graph(data, n_diagnoses = n_diag)
  centrality_table <- calculate_subgraph_centrality(graph)
  centrality_similarity(centrality_table)
}

# Step 1: cosine similarities between criteria, computed per model
cosine_by_model <- listN(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) %>%
  enframe(name = "model", value = "data") %>%
  mutate(data = map(data, centrality_wrapper)) %>%
  unnest(data)

# Step 2: average each criteria pair over the models and reshape to a matrix
average_cosine_matrix <- cosine_by_model %>%
  dplyr::summarise(cosine = mean(cosine), .by = c("V1", "V2")) %>%
  pivot_wider(names_from = "V2", values_from = "cosine") %>%
  column_to_rownames("V1") %>%
  as.matrix()

average_cosine_matrix


```



```{r, fig.width=5.5, fig.height=3.5}
# Draw a ComplexHeatmap of a numeric matrix using a three-point colour ramp.
#
# data:               numeric matrix to plot
# plot_title:         optional column title drawn above the heatmap (NULL = none)
# legend_title:       optional; passed as the heatmap `name`
# color_scale:        three colours for (min, midpoint, max); NULL selects
#                     hcl.colors(3, "Earth") when symmetric, viridis otherwise
# midpoint:           value mapped to the middle colour; NULL = centre of scale
# symmetric:          TRUE  -> scale spans -max(|data|) .. max(|data|)
#                     FALSE -> scale spans min(data) .. max(data)
# matrix_title_size:  font size of plot_title
# matrix_names_size:  font size of row and column names
# legend_title_size, legend_label_size: legend font sizes
# dendrograms:        show row and column dendrograms
# dendrogram_weight:  dendrogram height/width (only used when dendrograms = TRUE)
# legend_orientation: "vertical" (legend on the right) or "horizontal" (bottom)
# legend_length:      optional legend height (vertical) or width (horizontal)
# grid_lines:         draw a black border around every cell
#
# Draws the heatmap and returns the value of draw().
custom_heatmap <- function(data,
                           plot_title = NULL,
                           legend_title = NULL,
                           color_scale = NULL,
                           midpoint = NULL,
                           symmetric = TRUE,
                           matrix_title_size = 10,
                           matrix_names_size = 8,
                           legend_title_size = 10,
                           legend_label_size = 8,
                           dendrograms = TRUE,
                           dendrogram_weight = unit(10, "mm"),
                           legend_orientation = "vertical",
                           legend_length = NULL,
                           grid_lines = FALSE) {
  # Limits and centre of the colour scale (missing cells are ignored)
  if (symmetric) {
    scale_max <- max(abs(data), na.rm = TRUE)
    scale_min <- -scale_max
  } else {
    scale_max <- max(data, na.rm = TRUE)
    scale_min <- min(data, na.rm = TRUE)
  }
  if (is.null(midpoint)) {
    midpoint <- scale_min + (scale_max - scale_min) / 2
  }

  # Default colour scale depends on whether the data are symmetric
  if (is.null(color_scale)) {
    color_scale <- if (symmetric) hcl.colors(3, "Earth") else viridis::viridis(3)
  }
  color_function <- circlize::colorRamp2(
    c(scale_min, midpoint, scale_max),
    color_scale
  )

  # Legend parameters
  legend_params <- list(
    "title_gp" = grid::gpar(fontsize = legend_title_size, fontface = "bold"),
    "labels_gp" = grid::gpar(fontsize = legend_label_size),
    "direction" = legend_orientation
  )
  if (!is.null(legend_length)) {
    if (legend_orientation == "vertical") {
      legend_params[["legend_height"]] <- legend_length
    } else if (legend_orientation == "horizontal") {
      legend_params[["legend_width"]] <- legend_length
    }
  }
  legend_side <- if (legend_orientation == "vertical") "right" else "bottom"

  # Heatmap parameters; optional ones are only added when requested
  heatmap_arguments <- list(
    "matrix" = data,
    "col" = color_function,
    "row_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "column_names_gp" = grid::gpar(fontsize = matrix_names_size),
    "show_column_dend" = dendrograms,
    "show_row_dend" = dendrograms,
    "heatmap_legend_param" = legend_params
  )
  if (!is.null(plot_title)) {
    # Previously plot_title / matrix_title_size were accepted but never used
    heatmap_arguments[["column_title"]] <- plot_title
    heatmap_arguments[["column_title_gp"]] <- grid::gpar(fontsize = matrix_title_size)
  }
  if (!is.null(legend_title)) {
    heatmap_arguments[["name"]] <- legend_title
  }
  if (grid_lines) {
    heatmap_arguments[["rect_gp"]] <- grid::gpar(col = "black", lwd = 1)
  }
  if (dendrograms) {
    heatmap_arguments[["column_dend_height"]] <- dendrogram_weight
    heatmap_arguments[["row_dend_width"]] <- dendrogram_weight
  }

  ht <- do.call(ComplexHeatmap::Heatmap, heatmap_arguments)
  ComplexHeatmap::draw(
    ht,
    heatmap_legend_side = legend_side,
    align_heatmap_legend = "global_center"
  )
}

custom_heatmap(
  average_cosine_matrix,
  symmetric = FALSE,
  legend_title = "Cosine similarity",
  grid_lines = TRUE,
  dendrograms = FALSE,
  legend_orientation = "horizontal",
  legend_length = unit(10, "cm")
)
```
```{r}
# Long-format comparison table: one nested entry per model's ICD co-diagnosis
# frame, each passed through centrality_wrapper() and then flattened. The
# `model` column holds the names assigned by listN().
df_comp <- listN(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) %>%
  tibble::enframe(name = "model", value = "data") %>%
  dplyr::mutate(data = purrr::map(data, centrality_wrapper)) %>%
  tidyr::unnest(data)
```
```{r, fig.width=3.5, fig.height=4}
#' Dot plot of per-model cosine similarity for two selected criteria pairs
#'
#' @param df Data frame with columns `V1`, `V2`, `cosine` and `model`; it is
#'   also passed through `format_models()`.
#' @param pt_size Point size passed to `geom_point()`.
#' @param p_size Text size of the Wilcoxon p-value label.
#' @return A ggplot object.
cosine_similarity_compare <- function(df, pt_size = 1, p_size = 3) {
  # Pairs to keep, written as "<V1>_<V2>" because unite() joins with "_" by default
  pairs_to_show <- c(
    "MCAS - Consortium_MCAS - Alternative",
    "SLE - EULAR-ACR_SLE - SLICC"
  )

  # Collapse V1/V2 into one label, keep the two pairs, and turn the "_" joiner
  # into a line break for the axis labels
  plot_data <- df %>%
    unite(comp, V1, V2) %>%
    filter(comp %in% pairs_to_show) %>%
    mutate(comp = gsub("_", "\nvs. ", comp)) %>%
    format_models()

  # One point per model, with a Wilcoxon test comparing the two pairs
  ggplot(plot_data, aes(x = comp, y = cosine, color = model)) +
    geom_point(size = pt_size, position = position_dodge(width = 0.75)) +
    ggpubr::stat_compare_means(
      aes(group = comp),
      method = "wilcox.test",
      label = "p",
      vjust = 0.75,
      show.legend = FALSE,
      size = p_size
    ) +
    scale_color_brewer(palette = "Dark2") +
    labs(x = NULL, y = "Cosine similarity", color = "") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, hjust = 1))
}

cosine_similarity_compare(df_comp)
```

```{r}
# Pairwise cosine similarity between every model-criteria combination.
#
# Each ICD co-diagnosis frame is turned into a graph, scored with
# calculate_subgraph_centrality(), and reshaped to long form
# (criteria, diagnosis, centrality). The combined table is then widened to a
# (model-criteria) x diagnosis matrix, with 0 filled in for diagnoses a
# combination never produced. lsa::cosine() compares columns, hence the
# transpose.
#
# The result of lsa::cosine() is already the square, named matrix needed
# downstream, so the former data-frame -> pivot_longer -> pivot_wider ->
# matrix round trip (which rebuilt the same matrix) has been dropped.
individual_cosine_matrix <- combine_data_frames(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag,
  additional_function = function(x) {
    x %>%
      make_codiagnosis_graph() %>%
      calculate_subgraph_centrality() %>%
      rownames_to_column("criteria") %>%
      pivot_longer(-criteria, names_to = "diagnosis", values_to = "centrality")
  }
) %>%
  # `original_df` is presumably added by combine_data_frames() to identify the
  # source frame -- TODO confirm
  unite(criteria, original_df, criteria, sep = "-") %>%
  pivot_wider(
    names_from = "diagnosis",
    values_from = "centrality",
    values_fill = 0
  ) %>%
  column_to_rownames("criteria") %>%
  as.matrix() %>%
  t() %>%
  lsa::cosine()

individual_cosine_matrix[1:5, 1:5]
```

```{r, fig.width=5.5, fig.height=3.5}
# Heatmap of the per-model, per-criteria cosine-similarity matrix on a
# non-symmetric viridis scale
alt_heatmap <- model_criteria_heatmap(
  individual_cosine_matrix,
  color_scale = viridis::viridis(3),
  title = " ",
  metric = "Cosine\nsimilarity",
  symmetric = FALSE,
  font_size = 8
)

alt_heatmap
```
```{r, fig.width=6, fig.height=4}
# Capture the drawn heatmap as a grid grob and wrap it with cowplot so it can
# be printed and saved like any other plot object
plt_alt <- cowplot::plot_grid(
  grid::grid.grabExpr(
    ComplexHeatmap::draw(alt_heatmap)
  )
)
plt_alt
```
```{r}
# Export the wrapped heatmap as a 6 x 4 PDF
ggsave(
  filename = here("figures/4_alt_heatmap.pdf"),
  plot = plt_alt,
  width = 6,
  height = 4
)
```

# Final plot

### Version 1
```{r, fig.width=3.5, fig.height=5.5}
# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50

# Text sizes and legend padding shared by the panels of this figure
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0

# Theme overrides added to ggplot panels so that text and legend sizing match
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(
      size = label_size,
      margin = margin(t = 0, r = 1, b = 0, l = 0)
    ),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    # `linewidth` replaces the deprecated `size` argument for rectangle borders
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
)

# Tight facet-strip margins (defined here; not applied to any panel below)
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(
    margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  ),
  strip.text.y = element_text(
    margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  )
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Panel A: co-diagnosis network. The seed is fixed before the layout is built.
# NOTE(review): despite the `_icd` suffix this panel is built from the non-ICD
# `*_codiag` frames, unlike the edge-density panel below -- confirm intended.
set.seed(1234)
plt_graph_icd <-
  multi_make_codiagnosis_graph(
    threshold_method = "average",
    top_n = 100,
    layout = "stress",
    df_gpt3.5_codiag,
    df_gpt4.0_codiag,
    df_claude3_haiku_t1.0_codiag,
    df_claude3_opus_t1.0_codiag,
    df_gemini1.0_pro_t1.0_codiag,
    df_gemini1.5_pro_t1.0_codiag,
    point_size = 1.25,
    border_size = 0.25,
    edge_width = 0.5,
    edge_alpha = 0.5,
    label_text_size = 9,
    tick_text_size = 6,
    highlight_stroke_multiplier = 3,
    legend_height = unit(25, "pt"),
    legend_width = unit(10, "pt")
  )

# Panel B: edge-density plot with a two-column legend underneath
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) +
  apply_text_formatting +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(override.aes = list(size = 1), ncol = 2)) +
  scale_y_continuous(expand = expansion(mult = c(0.1, 0.1)))

# Panel C: cosine-similarity heatmap. custom_heatmap() draws as a side effect;
# the returned object is drawn again below inside grid.grabExpr().
plt_heatmap_icd <-
  custom_heatmap(
    average_cosine_matrix,
    symmetric = FALSE,
    legend_title = "Cosine similarity",
    grid_lines = TRUE,
    dendrograms = TRUE,
    legend_orientation = "horizontal",
    legend_length = unit(2, "cm"),
    matrix_names_size = 6,
    legend_title_size = 7.5,
    legend_label_size = 6,
    dendrogram_weight = unit(2, "mm")
  )

# Assemble the figure: spacer / network (A) / spacer / [edge density (B) |
# spacer | heatmap (C)]. NULL entries are empty spacer cells.
plt_fig <- cowplot::plot_grid(
  NULL,
  cowplot::plot_grid(
    NULL, plt_graph_icd,
    nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  NULL,
  cowplot::plot_grid(
    cowplot::plot_grid(NULL, plt_edge_icd, rel_heights = c(0.1, 1), ncol = 1),
    NULL,
    cowplot::plot_grid(
      NULL,
      grid::grid.grabExpr(
        ComplexHeatmap::draw(plt_heatmap_icd, heatmap_legend_side = "bottom")
      ),
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.5, 0.1, 0.6),
    labels = c("B", "", "C"),
    # `align` takes a direction ("h"/"v") and `axis` takes side letters; the
    # two values were previously passed the wrong way round
    align = "h", axis = "bt"
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.01, 1),
  labels = c("A")
)

plt_fig
```

### Version 2
```{r, fig.width=3.5, fig.height=5.5}
# n_diagnoses_bar <- 10
# n_diagnoses_abundance <- 50
# n_diagnoses_cumulative <- 50

# Text sizes and legend padding shared by the panels of this figure
title_size <- 9
label_size <- 6
legend_x_pad <- 0
legend_y_pad <- 0

# Theme overrides added to ggplot panels so that text and legend sizing match
apply_text_formatting <- list(
  theme(
    axis.text = element_text(size = label_size),
    axis.title = element_text(size = title_size),
    legend.text = element_text(
      size = label_size,
      margin = margin(t = 0, r = 1, b = 0, l = 0)
    ),
    strip.text = element_text(size = label_size + 1),
    legend.key.height = unit(0.4, "cm"),
    legend.key.width = unit(0.4, "cm"),
    # legend.key = element_rect(size =  margin(0,0,0,0)),
    # `linewidth` replaces the deprecated `size` argument for rectangle borders
    legend.box.background = element_rect(color = "black", linewidth = 1),
    legend.margin = margin(
      t = legend_y_pad,
      r = legend_x_pad + 2,
      b = legend_y_pad,
      l = legend_x_pad
    ),
    legend.key.spacing.y = unit(-1.5, "pt"),
    legend.box.spacing = unit(1.0, "pt"),
    axis.text.x = element_text(angle = 90, vjust = 0.5)
  )
)

# Tight facet-strip margins (defined here; not applied to any panel below)
strip_margin <- 1
strip_formatting <- list(theme(
  strip.text.x = element_text(
    margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  ),
  strip.text.y = element_text(
    margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin)
  )
  # strip.background = element_rect(margin = margin(t = strip_margin, r = strip_margin, b = strip_margin, l = strip_margin))
))

# Panel A: co-diagnosis network. The seed is fixed before the layout is built.
# NOTE(review): despite the `_icd` suffix this panel is built from the non-ICD
# `*_codiag` frames, unlike the edge-density panel below -- confirm intended.
set.seed(1234)
plt_graph_icd <-
  multi_make_codiagnosis_graph(
    threshold_method = "average",
    top_n = 100,
    layout = "stress",
    df_gpt3.5_codiag,
    df_gpt4.0_codiag,
    df_claude3_haiku_t1.0_codiag,
    df_claude3_opus_t1.0_codiag,
    df_gemini1.0_pro_t1.0_codiag,
    df_gemini1.5_pro_t1.0_codiag,
    point_size = 1.25,
    border_size = 0.25,
    edge_width = 0.5,
    edge_alpha = 0.5,
    label_text_size = 9,
    tick_text_size = 6,
    highlight_stroke_multiplier = 3,
    legend_height = unit(25, "pt"),
    legend_width = unit(10, "pt")
  )

# Panel B: edge-density plot. Its legend is kept at the default position here
# and extracted separately when the figure is assembled.
plt_edge_icd <- multi_edge_density_plot(
  df_gpt3.5_icd_codiag,
  df_gpt4.0_icd_codiag,
  df_claude3_haiku_t1.0_icd_codiag,
  df_claude3_opus_t1.0_icd_codiag,
  df_gemini1.0_pro_t1.0_icd_codiag,
  df_gemini1.5_pro_t1.0_icd_codiag
) +
  apply_text_formatting +
  # theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(color = guide_legend(override.aes = list(size = 1), ncol = 2)) +
  scale_y_continuous(expand = expansion(mult = c(0.1, 0.1)))

# Panel C: per-model cosine similarity for the selected criteria pairs
plt_cosine <- cosine_similarity_compare(df_comp, p_size = 2) +
  apply_text_formatting

# Panel D: cosine-similarity heatmap. custom_heatmap() draws as a side effect;
# the returned object is drawn again below inside grid.grabExpr().
plt_heatmap_icd <-
  custom_heatmap(
    average_cosine_matrix,
    symmetric = FALSE,
    legend_title = "Cosine similarity",
    grid_lines = TRUE,
    dendrograms = TRUE,
    legend_orientation = "horizontal",
    legend_length = unit(2.5, "cm"),
    matrix_names_size = 6,
    legend_title_size = 7.5,
    legend_label_size = 6,
    dendrogram_weight = unit(2.5, "mm")
  )

# Not referenced anywhere else in this chunk
pd <- -1

# Assemble the figure. NULL entries are empty spacer cells.
plt_fig <- cowplot::plot_grid(
  # Row 1: top spacer
  NULL,
  # Row 2: network panel, offset by a narrow spacer column on the left
  cowplot::plot_grid(
    NULL, plt_graph_icd,
    nrow = 1, rel_widths = c(0.05, 0.9)
  ),
  # Row 3: spacer
  NULL,
  # Row 4: lower panels
  cowplot::plot_grid(
    # Left column: edge density + cosine comparison, then the shared legend
    cowplot::plot_grid(
      cowplot::plot_grid(
        plt_edge_icd + theme(legend.position = "none"),
        # NOTE(review): ylim() sets scale limits, so points outside [0.4, 1]
        # would be removed before the p-value is computed;
        # coord_cartesian(ylim = ...) would only zoom -- confirm intended.
        plt_cosine + theme(legend.position = "none") + ylim(0.4, 1),
        # rel_widths = c(10,9),
        nrow = 1,
        align = "h", axis = "bt"
      ),
      # Spacers use NULL (plot_grid's empty-cell placeholder) rather than NA,
      # which is not a plot object
      NULL,
      cowplot::plot_grid(
        NULL, get_legend(plt_edge_icd),
        nrow = 1, rel_widths = c(0.15, 1)
      ),
      NULL,
      ncol = 1,
      rel_heights = c(1, 0.04, 0.1, 0.1)
    ),
    # Right column: heatmap with spacers above and below
    cowplot::plot_grid(
      NULL,
      grid::grid.grabExpr(
        ComplexHeatmap::draw(
          plt_heatmap_icd,
          heatmap_legend_side = "bottom",
          padding = unit(c(0, 0, 0, -1), "mm")
        )
      ),
      NULL,
      ncol = 1,
      rel_heights = c(0.1, 1, 0.1)
    ),
    nrow = 1,
    rel_widths = c(0.6, 0.5),
    align = "h", axis = "bt"
  ),
  ncol = 1,
  rel_heights = c(0.05, 0.95, 0.05, 1)
)

# Panel letters are placed manually in relative figure coordinates
plt_fig <- cowplot::ggdraw(plt_fig) +
  cowplot::draw_plot_label(
    c("A", "B", "C", "D"),
    x = c(0, 0, 0.25, 0.52),
    y = c(1, rep(0.52, 3))
  )
plt_fig
```


```{r, eval=F}
# Export the assembled figure as a 3.5 x 5.5 PDF
ggsave(
  filename = here("figures/4_Network_analysis.pdf"),
  plot = plt_fig,
  width = 3.5,
  height = 5.5
)
```

